package edu.umd.clip.lm.tools;

import java.io.*;
import java.nio.channels.*;
import java.util.*;

import edu.berkeley.nlp.util.*;
import edu.umd.clip.lm.factors.FactorTuple;
import edu.umd.clip.lm.model.*;
import edu.umd.clip.lm.model.data.*;
import edu.umd.clip.lm.ngram.Ngram;

/**
 * Prunes singleton n-grams from training data: a first pass over all data files collects the
 * n-grams of the given order (on overt factors only) that occur exactly once in the entire
 * data set; each file is then renamed with a ".bak" suffix and rewritten with those singleton
 * futures removed.
 *
 * @author Denis Filimonov <den@cs.umd.edu>
 */
public class PruneTrainingData {
	public static class Options {
		@Option(name = "-data-files", required = true, usage = "Training data files (comma-separated)")
		public String dataFiles;

		@Option(name = "-config", required = true, usage = "XML config file")
		public String config;

		@Option(name = "-order", required = false, usage = "order of n-grams to prune (default 3)")
		public int order = 3;
	}

	/**
	 * @param args
	 * @throws IOException 
	 */
	public static void main(String[] args) throws IOException {
		OptionParser optParser = new OptionParser(Options.class);
		Options opts = (Options) optParser.parse(args, true);

		Experiment.initialize(opts.config);
		Experiment experiment = Experiment.getInstance();
		experiment.buildPrefixes();

		long overtMask = experiment.getTupleDescription().getOvertFactorsMask();
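		// N-grams are keyed on the overt factors only, so tuples are masked before hashing.
		// Two passes over the data: the first collects n-grams that occur exactly once anywhere
		// in the data, the second rewrites the files with those singleton futures dropped.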

		HashSet<Ngram> singletons = new HashSet<Ngram>(1000000);
		HashSet<Ngram> multiples = new HashSet<Ngram>(1000000);
		
		String [] fileNames = opts.dataFiles.split(",");
		
		// collect singleton stats
		for(String fileName : fileNames) {
			ReadableByteChannel channel = new FileInputStream(fileName).getChannel();
			TrainingDataReader reader = new OnDiskTrainingDataReader(channel);
			ReadableTrainingData inputData = new ReadableTrainingData(reader);

			while(inputData.hasNext()) {
				TrainingDataBlock block = inputData.next();
				
				for(ContextFuturesPair pair : block) {
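					// Mask the context in place; this pass only counts n-grams, so the
					// block's data is not needed afterwards.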
					long [] context = pair.getContext().data;
					for(int i=0; i<context.length; ++i) {
						context[i] &= overtMask;
					}
					
					for(TupleCountPair tc : pair.getFutures()) {
						Ngram ngram = Ngram.makeNgram(opts.order, context, tc.tuple & overtMask);
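						// A future with count > 1, or an n-gram seen before, can never be a
						// singleton: move it to (or keep it in) the multiples set.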

						if (tc.count > 1 || singletons.contains(ngram) || multiples.contains(ngram)) {
							singletons.remove(ngram);
							multiples.add(ngram);
						} else {
							singletons.add(ngram);
						}
					}
					
				}
			}
			channel.close();
			System.out.printf("File %s scanned: %d singletons, %d multiples\n", fileName, singletons.size(), multiples.size());
		}
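		// The multiples set was only needed to tell repeats from first occurrences; free it now.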
		multiples.clear();
		
		// move the originals out of the way; pruned data will be written under the original names
		for(String fileName : fileNames) {
			File f = new File(fileName);
			if (!f.renameTo(new File(fileName + ".bak"))) {
				throw new IOException("Failed to rename " + fileName + " to " + fileName + ".bak");
			}
		}
		
		// filter files
		for(String fileName : fileNames) {
			ReadableByteChannel channel = new FileInputStream(fileName + ".bak").getChannel();
			TrainingDataReader reader = new OnDiskTrainingDataReader(channel);
			ReadableTrainingData inputData = new ReadableTrainingData(reader);
		
			FileChannel output = new FileOutputStream(fileName).getChannel();
			TrainingDataWriter writer = new OnDiskTrainingDataWriter(output);
			WritableTrainingData outputData = writer.createData();
			
			TrainingDataBlock outputBlock = new TrainingDataBlock();
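			// Filtered pairs are re-packed into blocks; a block is flushed to the writer once it fills up.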
			
			long pruned = 0;
			while(inputData.hasNext()) {
				TrainingDataBlock block = inputData.next();
				
				for(ContextFuturesPair pair : block) {
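					// Mask a copy of the context: the pair itself is written back out below,
					// so its data must not be modified.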
					long [] context = Arrays.copyOf(pair.getContext().data, pair.getContext().data.length);
					for(int i=0; i<context.length; ++i) {
						context[i] &= overtMask;
					}
					
					TupleCountPair[] newFutures = new TupleCountPair[pair.getFutures().length];
					int nrNewFutures = 0;
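					// Keep a future if its count in this context exceeds 1, or if its n-gram is
					// not a corpus-wide singleton; otherwise drop it.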
					for(TupleCountPair tc : pair.getFutures()) {
						if (tc.count > 1) {
							newFutures[nrNewFutures++] = tc;
							continue;
						}
						
						Ngram ngram = Ngram.makeNgram(opts.order, context, tc.tuple & overtMask);
						
						if (!singletons.contains(ngram)) {
							newFutures[nrNewFutures++] = tc;
						} else {
							pruned += 1;
						}
					}
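					// All futures of this context were singletons; drop the pair entirely.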
					if (nrNewFutures == 0) continue;
					
					if (nrNewFutures < newFutures.length) {
						newFutures = Arrays.copyOf(newFutures, nrNewFutures);
					}
					
					ContextFuturesPair prunedPair = new ContextFuturesPair(pair.getContext(), newFutures);
					// add() is assumed to return false when the block is full: flush the current
					// block and start a new one with the pruned pair (not the original one).
					if (!outputBlock.add(prunedPair)) {
						outputData.add(outputBlock);
						outputBlock = new TrainingDataBlock();
						outputBlock.add(prunedPair);
					}
					
				}
			}
			channel.close();

			if (outputBlock.hasData()) {
				outputData.add(outputBlock);
			}
			outputData.finish();
			
			output.close();
			System.out.printf("File %s done, %d singletons pruned\n", fileName, pruned);
		}
	}

}
